[NVPTX] Customize getScalarizationOverhead #128077
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from dca12e5 to ec973a7
if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
  return TTI::TCC_Free;
auto *VecTy = getWidenedType(ScalarTy, VL.size());
if (auto Cost = TTI.getBuildVectorCost(VecTy, VL, CostKind);
I'd be happy to hear alternatives that make better use of the existing getVectorInstrCost interface. I had tried pattern matching InsertElt(InsertElt(poison, a, 0), b, 1), but found that the second insert is also called with poison, so we can't pattern match it. E.g. here both calls are passed poison:
llvm-project/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
Lines 5883 to 5888 in 6bfedfa
InstructionCost InsertFirstCost = TTI->getVectorInstrCost(
    Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, 0,
    PoisonValue::get(Ty), *It);
InstructionCost InsertIdxCost = TTI->getVectorInstrCost(
    Instruction::InsertElement, Ty, TTI::TCK_RecipThroughput, Idx,
    PoisonValue::get(Ty), *It);
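For concreteness, a hypothetical sketch of the matcher that was attempted (the helper name is illustrative, not from the patch). It can never fire for the second insert, because both calls above pass PoisonValue as the vector operand, so the matcher never sees the inner InsertElt:

#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// Returns true if Op0 is the inner half of
// InsertElt(InsertElt(poison, a, 0), b, 1). With SLP's costing calls,
// Op0 is always poison here, so this match never succeeds.
static bool isFirstHalfInsert(Value *Op0) {
  Value *A;
  return match(Op0, m_InsertElt(m_Poison(), m_Value(A), m_ZeroInt()));
}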
It should not be here. Rework getScalarizationOverhead instead.
Thanks, that's exactly what I was looking for!
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-vectorizers

Author: None (peterbell10)

Changes

We've observed that the SLPVectorizer is too conservative on NVPTX because it over-estimates the cost to build a vector. PTX has a single mov instruction that can build e.g. <2 x half> vectors from scalars; however, the SLPVectorizer over-estimates it as the cost of 2 insert elements.

To fix this I add TargetTransformInfo::getBuildVectorCost so the target can optionally specify the exact cost.

Full diff: https://github.com/llvm/llvm-project/pull/128077.diff

7 Files Affected:
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 9048481b49189..c3a17a6fdb29d 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1479,6 +1479,12 @@ class TargetTransformInfo {
InstructionCost getInsertExtractValueCost(unsigned Opcode,
TTI::TargetCostKind CostKind) const;
+ /// \return The cost of ISD::BUILD_VECTOR, or nullopt if the cost should be
+ /// inferred from insert element and shuffle ops.
+ std::optional<InstructionCost>
+ getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+ TargetCostKind CostKind) const;
+
/// \return The cost of replication shuffle of \p VF elements typed \p EltTy
/// \p ReplicationFactor times.
///
@@ -2224,6 +2230,10 @@ class TargetTransformInfo::Concept {
TTI::TargetCostKind CostKind,
unsigned Index) = 0;
+ virtual std::optional<InstructionCost>
+ getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+ TargetCostKind CostKind) = 0;
+
virtual InstructionCost
getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
const APInt &DemandedDstElts,
@@ -2952,6 +2962,12 @@ class TargetTransformInfo::Model final : public TargetTransformInfo::Concept {
unsigned Index) override {
return Impl.getVectorInstrCost(I, Val, CostKind, Index);
}
+ std::optional<InstructionCost>
+ getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+ TTI::TargetCostKind CostKind) override {
+ return Impl.getBuildVectorCost(VecTy, Operands, CostKind);
+ }
+
InstructionCost
getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
const APInt &DemandedDstElts,
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
index a8d6dd18266bb..f7ef03bea2548 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -739,6 +739,12 @@ class TargetTransformInfoImplBase {
return 1;
}
+ std::optional<InstructionCost>
+ getBuildVectorCost(VectorType *Val, ArrayRef<Value *> Operands,
+ TTI::TargetCostKind CostKind) const {
+ return std::nullopt;
+ }
+
unsigned getReplicationShuffleCost(Type *EltTy, int ReplicationFactor, int VF,
const APInt &DemandedDstElts,
TTI::TargetCostKind CostKind) {
diff --git a/llvm/include/llvm/CodeGen/BasicTTIImpl.h b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
index 032c7d7b5159e..a58c8dcee49c2 100644
--- a/llvm/include/llvm/CodeGen/BasicTTIImpl.h
+++ b/llvm/include/llvm/CodeGen/BasicTTIImpl.h
@@ -1373,6 +1373,12 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
Op1);
}
+ std::optional<InstructionCost>
+ getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+ TTI::TargetCostKind CostKind) {
+ return std::nullopt;
+ }
+
InstructionCost getReplicationShuffleCost(Type *EltTy, int ReplicationFactor,
int VF,
const APInt &DemandedDstElts,
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 1ca9a16b18112..69b8f6706b563 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1123,6 +1123,13 @@ InstructionCost TargetTransformInfo::getInsertExtractValueCost(
return Cost;
}
+std::optional<InstructionCost>
+TargetTransformInfo::getBuildVectorCost(VectorType *VecTy,
+ ArrayRef<Value *> Operands,
+ TargetCostKind CostKind) const {
+ return TTIImpl->getBuildVectorCost(VecTy, Operands, CostKind);
+}
+
InstructionCost TargetTransformInfo::getReplicationShuffleCost(
Type *EltTy, int ReplicationFactor, int VF, const APInt &DemandedDstElts,
TTI::TargetCostKind CostKind) const {
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
index b0a846a9c7f96..cbe94a2e82279 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h
@@ -16,8 +16,9 @@
#ifndef LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
#define LLVM_LIB_TARGET_NVPTX_NVPTXTARGETTRANSFORMINFO_H
-#include "NVPTXTargetMachine.h"
#include "MCTargetDesc/NVPTXBaseInfo.h"
+#include "NVPTXTargetMachine.h"
+#include "NVPTXUtilities.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/BasicTTIImpl.h"
#include "llvm/CodeGen/TargetLowering.h"
@@ -100,6 +101,26 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
TTI::OperandValueInfo Op2Info = {TTI::OK_AnyValue, TTI::OP_None},
ArrayRef<const Value *> Args = {}, const Instruction *CxtI = nullptr);
+ std::optional<InstructionCost>
+ getBuildVectorCost(VectorType *VecTy, ArrayRef<Value *> Operands,
+ TTI::TargetCostKind CostKind) {
+ if (CostKind != TTI::TCK_RecipThroughput)
+ return std::nullopt;
+ auto VT = getTLI()->getValueType(DL, VecTy);
+ if (all_of(Operands, [](Value *Op) { return isa<Constant>(Op); }))
+ return TTI::TCC_Free;
+ if (Isv2x16VT(VT))
+ return 1; // Single vector mov
+ if (VT == MVT::v4i8) {
+ InstructionCost Cost = 3; // 3 x PRMT
+ for (auto *Op : Operands)
+ if (!isa<Constant>(Op))
+ Cost += 1; // zext operand to i32
+ return Cost;
+ }
+ return std::nullopt;
+ }
+
void getUnrollingPreferences(Loop *L, ScalarEvolution &SE,
TTI::UnrollingPreferences &UP,
OptimizationRemarkEmitter *ORE);
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
index f2aa0e8328585..cfe7bbc641906 100644
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -10203,6 +10203,9 @@ class BoUpSLP::ShuffleCostEstimator : public BaseShuffleAnalysis {
if ((!Root && allConstant(VL)) || all_of(VL, IsaPred<UndefValue>))
return TTI::TCC_Free;
auto *VecTy = getWidenedType(ScalarTy, VL.size());
+ if (auto Cost = TTI.getBuildVectorCost(VecTy, VL, CostKind);
+ Cost.has_value())
+ return *Cost;
InstructionCost GatherCost = 0;
SmallVector<Value *> Gathers(VL);
if (!Root && isSplat(VL)) {
diff --git a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
index 13773bf901b9b..c74909d7ceb2a 100644
--- a/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
+++ b/llvm/test/Transforms/SLPVectorizer/NVPTX/v2f16.ll
@@ -1,59 +1,123 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s
-; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_40 | FileCheck %s -check-prefix=NOVECTOR
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_90 | FileCheck %s -check-prefixes=VECTOR,SM90
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 | FileCheck %s -check-prefixes=VECTOR,SM80
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_70 | FileCheck %s -check-prefixes=VECTOR,SM70
+; RUN: opt < %s -passes=slp-vectorizer -S -mtriple=nvptx64-nvidia-cuda -mcpu=sm_50 | FileCheck %s -check-prefixes=NOVECTOR,SM50
define void @fusion(ptr noalias nocapture align 256 dereferenceable(19267584) %arg, ptr noalias nocapture readonly align 256 dereferenceable(19267584) %arg1, i32 %arg2, i32 %arg3) local_unnamed_addr #0 {
-; CHECK-LABEL: @fusion(
-; CHECK-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
-; CHECK-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
-; CHECK-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
-; CHECK-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
-; CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
-; CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
-; CHECK-NEXT: [[TMP1:%.*]] = load <2 x half>, ptr [[TMP11]], align 8
-; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x half> [[TMP1]], splat (half 0xH5380)
-; CHECK-NEXT: [[TMP3:%.*]] = fadd fast <2 x half> [[TMP2]], splat (half 0xH57F0)
-; CHECK-NEXT: store <2 x half> [[TMP3]], ptr [[TMP16]], align 8
-; CHECK-NEXT: ret void
+; VECTOR-LABEL: @fusion(
+; VECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
+; VECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
+; VECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
+; VECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; VECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
+; VECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
+; VECTOR-NEXT: [[TMP7:%.*]] = load <2 x half>, ptr [[TMP5]], align 8
+; VECTOR-NEXT: [[TMP8:%.*]] = fmul fast <2 x half> [[TMP7]], splat (half 0xH5380)
+; VECTOR-NEXT: [[TMP9:%.*]] = fadd fast <2 x half> [[TMP8]], splat (half 0xH57F0)
+; VECTOR-NEXT: store <2 x half> [[TMP9]], ptr [[TMP6]], align 8
+; VECTOR-NEXT: ret void
;
; NOVECTOR-LABEL: @fusion(
-; NOVECTOR-NEXT: [[TMP:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
-; NOVECTOR-NEXT: [[TMP4:%.*]] = or i32 [[TMP]], [[ARG3:%.*]]
-; NOVECTOR-NEXT: [[TMP5:%.*]] = shl nuw nsw i32 [[TMP4]], 2
-; NOVECTOR-NEXT: [[TMP6:%.*]] = zext i32 [[TMP5]] to i64
-; NOVECTOR-NEXT: [[TMP7:%.*]] = or disjoint i64 [[TMP6]], 1
-; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP6]]
-; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 8
+; NOVECTOR-NEXT: [[TMP1:%.*]] = shl nuw nsw i32 [[ARG2:%.*]], 6
+; NOVECTOR-NEXT: [[TMP2:%.*]] = or i32 [[TMP1]], [[ARG3:%.*]]
+; NOVECTOR-NEXT: [[TMP3:%.*]] = shl nuw nsw i32 [[TMP2]], 2
+; NOVECTOR-NEXT: [[TMP4:%.*]] = zext i32 [[TMP3]] to i64
+; NOVECTOR-NEXT: [[TMP10:%.*]] = or disjoint i64 [[TMP4]], 1
+; NOVECTOR-NEXT: [[TMP5:%.*]] = getelementptr inbounds half, ptr [[ARG1:%.*]], i64 [[TMP4]]
+; NOVECTOR-NEXT: [[TMP7:%.*]] = load half, ptr [[TMP5]], align 8
+; NOVECTOR-NEXT: [[TMP8:%.*]] = fmul fast half [[TMP7]], 0xH5380
+; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd fast half [[TMP8]], 0xH57F0
+; NOVECTOR-NEXT: [[TMP6:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP4]]
+; NOVECTOR-NEXT: store half [[TMP9]], ptr [[TMP6]], align 8
+; NOVECTOR-NEXT: [[TMP11:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP10]]
+; NOVECTOR-NEXT: [[TMP12:%.*]] = load half, ptr [[TMP11]], align 2
; NOVECTOR-NEXT: [[TMP13:%.*]] = fmul fast half [[TMP12]], 0xH5380
; NOVECTOR-NEXT: [[TMP14:%.*]] = fadd fast half [[TMP13]], 0xH57F0
-; NOVECTOR-NEXT: [[TMP16:%.*]] = getelementptr inbounds half, ptr [[ARG:%.*]], i64 [[TMP6]]
-; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP16]], align 8
-; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr inbounds half, ptr [[ARG1]], i64 [[TMP7]]
-; NOVECTOR-NEXT: [[TMP18:%.*]] = load half, ptr [[TMP17]], align 2
-; NOVECTOR-NEXT: [[TMP19:%.*]] = fmul fast half [[TMP18]], 0xH5380
-; NOVECTOR-NEXT: [[TMP20:%.*]] = fadd fast half [[TMP19]], 0xH57F0
-; NOVECTOR-NEXT: [[TMP21:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP7]]
-; NOVECTOR-NEXT: store half [[TMP20]], ptr [[TMP21]], align 2
+; NOVECTOR-NEXT: [[TMP15:%.*]] = getelementptr inbounds half, ptr [[ARG]], i64 [[TMP10]]
+; NOVECTOR-NEXT: store half [[TMP14]], ptr [[TMP15]], align 2
; NOVECTOR-NEXT: ret void
;
- %tmp = shl nuw nsw i32 %arg2, 6
- %tmp4 = or i32 %tmp, %arg3
- %tmp5 = shl nuw nsw i32 %tmp4, 2
- %tmp6 = zext i32 %tmp5 to i64
- %tmp7 = or disjoint i64 %tmp6, 1
- %tmp11 = getelementptr inbounds half, ptr %arg1, i64 %tmp6
- %tmp12 = load half, ptr %tmp11, align 8
- %tmp13 = fmul fast half %tmp12, 0xH5380
- %tmp14 = fadd fast half %tmp13, 0xH57F0
- %tmp16 = getelementptr inbounds half, ptr %arg, i64 %tmp6
- store half %tmp14, ptr %tmp16, align 8
- %tmp17 = getelementptr inbounds half, ptr %arg1, i64 %tmp7
- %tmp18 = load half, ptr %tmp17, align 2
- %tmp19 = fmul fast half %tmp18, 0xH5380
- %tmp20 = fadd fast half %tmp19, 0xH57F0
- %tmp21 = getelementptr inbounds half, ptr %arg, i64 %tmp7
- store half %tmp20, ptr %tmp21, align 2
+ %1 = shl nuw nsw i32 %arg2, 6
+ %4 = or i32 %1, %arg3
+ %5 = shl nuw nsw i32 %4, 2
+ %6 = zext i32 %5 to i64
+ %7 = or disjoint i64 %6, 1
+ %11 = getelementptr inbounds half, ptr %arg1, i64 %6
+ %12 = load half, ptr %11, align 8
+ %13 = fmul fast half %12, 0xH5380
+ %14 = fadd fast half %13, 0xH57F0
+ %16 = getelementptr inbounds half, ptr %arg, i64 %6
+ store half %14, ptr %16, align 8
+ %17 = getelementptr inbounds half, ptr %arg1, i64 %7
+ %18 = load half, ptr %17, align 2
+ %19 = fmul fast half %18, 0xH5380
+ %20 = fadd fast half %19, 0xH57F0
+ %21 = getelementptr inbounds half, ptr %arg, i64 %7
+ store half %20, ptr %21, align 2
ret void
}
+define ptx_kernel void @add_f16(ptr addrspace(1) %0, { half, half } %1, { half, half } %2) {
+; VECTOR-LABEL: @add_f16(
+; VECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
+; VECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
+; VECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
+; VECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
+; VECTOR-NEXT: [[TMP8:%.*]] = insertelement <2 x half> poison, half [[TMP4]], i32 0
+; VECTOR-NEXT: [[TMP9:%.*]] = insertelement <2 x half> [[TMP8]], half [[TMP5]], i32 1
+; VECTOR-NEXT: [[TMP10:%.*]] = insertelement <2 x half> poison, half [[TMP6]], i32 0
+; VECTOR-NEXT: [[TMP11:%.*]] = insertelement <2 x half> [[TMP10]], half [[TMP7]], i32 1
+; VECTOR-NEXT: [[TMP12:%.*]] = fadd <2 x half> [[TMP9]], [[TMP11]]
+; VECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; VECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
+; VECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
+; VECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
+; VECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
+; VECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
+; VECTOR-NEXT: ret void
+;
+; NOVECTOR-LABEL: @add_f16(
+; NOVECTOR-NEXT: [[TMP4:%.*]] = extractvalue { half, half } [[TMP1:%.*]], 0
+; NOVECTOR-NEXT: [[TMP5:%.*]] = extractvalue { half, half } [[TMP1]], 1
+; NOVECTOR-NEXT: [[TMP6:%.*]] = extractvalue { half, half } [[TMP2:%.*]], 0
+; NOVECTOR-NEXT: [[TMP7:%.*]] = extractvalue { half, half } [[TMP2]], 1
+; NOVECTOR-NEXT: [[TMP8:%.*]] = fadd half [[TMP4]], [[TMP6]]
+; NOVECTOR-NEXT: [[TMP9:%.*]] = fadd half [[TMP5]], [[TMP7]]
+; NOVECTOR-NEXT: [[TMP13:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+; NOVECTOR-NEXT: [[TMP14:%.*]] = shl i32 [[TMP13]], 1
+; NOVECTOR-NEXT: [[TMP15:%.*]] = and i32 [[TMP14]], 62
+; NOVECTOR-NEXT: [[TMP16:%.*]] = zext nneg i32 [[TMP15]] to i64
+; NOVECTOR-NEXT: [[TMP17:%.*]] = getelementptr half, ptr addrspace(1) [[TMP0:%.*]], i64 [[TMP16]]
+; NOVECTOR-NEXT: [[TMP19:%.*]] = insertelement <2 x half> poison, half [[TMP8]], i64 0
+; NOVECTOR-NEXT: [[TMP12:%.*]] = insertelement <2 x half> [[TMP19]], half [[TMP9]], i64 1
+; NOVECTOR-NEXT: store <2 x half> [[TMP12]], ptr addrspace(1) [[TMP17]], align 4
+; NOVECTOR-NEXT: ret void
+;
+ %5 = extractvalue { half, half } %1, 0
+ %6 = extractvalue { half, half } %1, 1
+ %7 = extractvalue { half, half } %2, 0
+ %8 = extractvalue { half, half } %2, 1
+ %9 = fadd half %5, %7
+ %10 = fadd half %6, %8
+ %11 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+ %12 = shl i32 %11, 1
+ %13 = and i32 %12, 62
+ %14 = zext nneg i32 %13 to i64
+ %15 = getelementptr half, ptr addrspace(1) %0, i64 %14
+ %18 = insertelement <2 x half> poison, half %9, i64 0
+ %19 = insertelement <2 x half> %18, half %10, i64 1
+ store <2 x half> %19, ptr addrspace(1) %15, align 4
+ ret void
+}
+
+; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
+declare noundef i32 @llvm.nvvm.read.ptx.sreg.tid.x() #1
+
attributes #0 = { nounwind }
+attributes #1 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
+;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
+; SM50: {{.*}}
+; SM70: {{.*}}
+; SM80: {{.*}}
+; SM90: {{.*}}
@llvm/pr-subscribers-backend-nvptx

Author: None (peterbell10)

(Same description and full diff as in the @llvm/pr-subscribers-llvm-transforms comment above.)
A single instruction in PTX does not mean that it's efficiently implemented in hardware. In this particular case, I think the current estimate, that constructing a v2f16 vector costs the equivalent of a few logical ops, is quite reasonable.
I think your example is a bit misleading because it includes the argument-passing convention; if we read the values from gmem instead, the mov becomes a single SASS instruction. I am a bit surprised, though, that ptxas keeps those extra movs.
I'm not very familiar with this API. Could you explain why we need to incur the cost of the zext? Doesn't prmt emit an i32?
We zext the inputs to the first 2 PRMT ops; see the relevant lowering here:
https://github.com/llvm/llvm-project/blob/b5bbe4eef3823facf83e85d2c11a97ce01882ea2/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp#L2147-L2165
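As a plain-C++ model of what that costed sequence computes (a sketch, not the actual lowering; the function name is mine): each non-constant operand is zero-extended to 32 bits, matching the +1 per operand in the cost code, and three PRMT byte-permutes pack the four low bytes into one 32-bit register.

#include <cstdint>

// Scalar model of the v4i8 build: zext each operand, then three
// byte-packing steps standing in for the three PRMTs.
uint32_t buildV4I8(uint8_t A, uint8_t B, uint8_t C, uint8_t D) {
  uint32_t RA = A, RB = B, RC = C, RD = D; // zext each operand to i32
  uint32_t Lo = RA | (RB << 8);            // PRMT #1: pack A and B
  uint32_t Hi = RC | (RD << 8);            // PRMT #2: pack C and D
  return Lo | (Hi << 16);                  // PRMT #3: combine the halves
}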
It occurs to me that we don't actually need the zero-extension part; we really just need to extend the register type, but unfortunately I can't think of a way to express that in PTX. Certainly the current lowering sometimes produces SASS that zeroes the top part of the register, e.g. see the LOP3 instruction in this:
https://godbolt.org/z/EPjjzv4cz
It seems that for the load ptxas knows the top part is zeroed, but if it's unsure it will mask the lower bits.
Fair enough. It's still not free, as it needs prmt/xmad. I think the cost adjustment here is a wash. In the end it will produce about the same instructions at the SASS level, whether we let LLVM construct it by element insertion or by logical ops -- there's just no way to bypass the fact that all registers at the SASS level are 32-bit, so we end up shuffling bits one way or another. Perhaps I'm missing something? Can you compile the test cases you've added to SASS?
The current heuristic results in a cost of 2 to build the <2 x half> vector.

Here are the results for sm_90 (the "PTX for sm_90" and "SASS for sm_90" listings are collapsed in the original comment). The SASS is a bit confusing, but I think …
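For reference, the generic estimate has roughly this shape (a sketch, not the exact BasicTTIImpl code; the helper name is illustrative): one insertelement cost per demanded lane, which is where the cost of 2 for a two-element build comes from.

// Sketch of the default scalarization overhead for a vector build:
// each demanded lane is charged as one insertelement.
static InstructionCost
defaultBuildCost(const TargetTransformInfo &TTI, FixedVectorType *VecTy,
                 const APInt &DemandedElts,
                 TargetTransformInfo::TargetCostKind CostKind) {
  InstructionCost Cost = 0;
  for (unsigned I = 0, E = VecTy->getNumElements(); I != E; ++I)
    if (DemandedElts[I])
      Cost += TTI.getVectorInstrCost(Instruction::InsertElement, VecTy,
                                     CostKind, I);
  return Cost;
}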
Here's a more apples-to-apples comparison, where I literally replaced … Afaict, the generated code is almost identical, modulo a different order of instructions.
@Artem-B I don't follow your point. Yes, …
Or do you mean that we should just rely on …
Sorry, I got tunnel-visioned into looking only at the conversion to/from f16x2 vectors. That part will probably remain the same regardless of how we do it at the PTX level. However, if the change does allow us to vectorize more scalar 16-bit ops, that is useful.
We've observed that the SLPVectorizer is too conservative on NVPTX because it over-estimates the cost to build a vector. PTX has a single `mov` instruction that can build <2 x half> vectors from scalars; however, the SLPVectorizer estimates the cost as 2 insert elements. To fix this I add `TargetTransformInfo::getBuildVectorCost` so the target can optionally specify the exact cost.
Force-pushed from e75affb to 52f75fc
We've observed that the SLPVectorizer is too conservative on NVPTX because it over-estimates the cost to build a vector. PTX has a single `mov` instruction that can build e.g. `<2 x half>` vectors from scalars; however, the SLPVectorizer over-estimates it as the cost of 2 insert elements.

To fix this I customize `getScalarizationOverhead` to lower the cost for building 2x16 types.
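A minimal sketch of that reworked approach (assumed shape, not the exact final patch; the signature follows the current getScalarizationOverhead hook, and the early-out mirrors the Isv2x16VT check from the earlier revision):

// Price a 2x16 build as a single vector mov; otherwise defer to the
// generic per-element estimate.
InstructionCost NVPTXTTIImpl::getScalarizationOverhead(
    VectorType *InTy, const APInt &DemandedElts, bool Insert, bool Extract,
    TTI::TargetCostKind CostKind, ArrayRef<Value *> VL) {
  if (Insert && Isv2x16VT(getTLI()->getValueType(DL, InTy)))
    return 1; // a single mov builds the whole <2 x half> / <2 x i16>
  return BaseT::getScalarizationOverhead(InTy, DemandedElts, Insert, Extract,
                                         CostKind, VL);
}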